{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 02_TransE_l2_AD_drug_repurposing\n",
    "#\n",
    "# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on January 8, 2023\n",
    "# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on March 11, 2023\n",
    "#\n",
    "# 该脚本展示了如何使用我们的预训练模型 (TransE_l2) 进行药物重定位 (Alzheimer's disease).\n",
    "#\n",
    "# 需要的包:\n",
    "#          csv\n",
    "#          numpy\n",
    "#          torch\n",
    "#\n",
    "# 需要的文件:\n",
    "#          ./prerequisites/infer_drug.tsv -> Drugbank 中的 FDA 批准的药物\n",
    "#          ./prerequisites/ad_drugs.txt\n",
    "#          ../../data/drkg/entities.tsv\n",
    "#          ../../data/drkg/relations.tsv\n",
    "#          ../01-model/ckpts/TransE_l2_All_DRKG_0/All_DRKG_TransE_l2_entity.npy\n",
    "#          ../01-model/ckpts/TransE_l2_All_DRKG_0/All_DRKG_TransE_l2_relation.npy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Alzheimer's disease Drug Repurposing via disease-compounds relations\n",
    "\n",
    "这个例子展示了如何使用我们的预训练模型 (TransE_l2)进行药物重定位."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collecting Alzheimer's disease\n",
    "\n",
    "一开始我们需要收集 DRKG 中的 AD 列表. 我们能够使用 DRKG 中的疾病 ID 编码疾病. 下面我们将全部的 AD 疾病作为目标. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:45.778730Z",
     "start_time": "2022-11-15T06:14:45.768734Z"
    }
   },
   "outputs": [],
   "source": [
    "AD_disease_list = [\n",
    "'Disease::DOID:10652',\n",
    "'Disease::MESH:C536599',\n",
    "'Disease::MESH:D000544'\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:45.791673Z",
     "start_time": "2022-11-15T06:14:45.780703Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(AD_disease_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:45.798655Z",
     "start_time": "2022-11-15T06:14:45.793668Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Disease::DOID:10652', 'Disease::MESH:C536599', 'Disease::MESH:D000544']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "AD_disease_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Candidate drugs\n",
    "\n",
    "现在我们使用 Drugbank 中的 FDA 批准的药物作为候选药物.（我们排除分子量 < 250 的药物）药物清单在 infer\\_drug.tsv 中."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:45.823616Z",
     "start_time": "2022-11-15T06:14:45.800677Z"
    }
   },
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "# Load entity file\n",
    "drug_list = []\n",
    "with open(\"./prerequisites/infer_drug.tsv\", newline='', encoding='utf-8') as csvfile:\n",
    "    reader = csv.DictReader(csvfile, delimiter='\\t', fieldnames=['drug','ids'])\n",
    "    for row_val in reader:\n",
    "        drug_list.append(row_val['drug'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:45.829572Z",
     "start_time": "2022-11-15T06:14:45.824585Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8104"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(drug_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:45.836563Z",
     "start_time": "2022-11-15T06:14:45.831576Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Compound::DB00605', 'Compound::DB00983', 'Compound::DB01240']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "drug_list[:3]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Treatment relation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:45.845529Z",
     "start_time": "2022-11-15T06:14:45.838548Z"
    }
   },
   "outputs": [],
   "source": [
    "treatment_list = [\n",
    "'DRUGBANK::treats::Compound:Disease',\n",
    "'GNBR::T::Compound:Disease',\n",
    "'Hetionet::CtD::Compound:Disease'\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:45.854506Z",
     "start_time": "2022-11-15T06:14:45.848528Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['DRUGBANK::treats::Compound:Disease',\n",
       " 'GNBR::T::Compound:Disease',\n",
       " 'Hetionet::CtD::Compound:Disease']"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "treatment_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DRKG 中 AD 的治疗药物"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "ad_drugs = []\n",
    "\n",
    "with open(\"./prerequisites/ad_drugs.txt\", encoding='utf-8') as f:\n",
    "    for drug in f:\n",
    "        ad_drugs.append(drug[:-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "126"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(ad_drugs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Compound::DB13588',\n",
       " 'Compound::DB09130',\n",
       " 'Compound::DB00244',\n",
       " 'Compound::DB04815',\n",
       " 'Compound::DB03994',\n",
       " 'Compound::DB00712',\n",
       " 'Compound::DB00158',\n",
       " 'Compound::DB00472',\n",
       " 'Compound::DB00331',\n",
       " 'Compound::DB01593',\n",
       " 'Compound::DB11100',\n",
       " 'Compound::DB03128',\n",
       " 'Compound::DB06712',\n",
       " 'Compound::DB00993',\n",
       " 'Compound::DB00877',\n",
       " 'Compound::DB11094',\n",
       " 'Compound::DB04115',\n",
       " 'Compound::DB11805',\n",
       " 'Compound::DB00843',\n",
       " 'Compound::DB00928',\n",
       " 'Compound::DB11068',\n",
       " 'Compound::DB00694',\n",
       " 'Compound::DB11748',\n",
       " 'Compound::DB14500',\n",
       " 'Compound::DB00382',\n",
       " 'Compound::DB00393',\n",
       " 'Compound::DB00490',\n",
       " 'Compound::DB00215',\n",
       " 'Compound::DB03843',\n",
       " 'Compound::DB06756',\n",
       " 'Compound::DB05289',\n",
       " 'Compound::DB00564',\n",
       " 'Compound::DB05381',\n",
       " 'Compound::DB01235',\n",
       " 'Compound::DB11780',\n",
       " 'Compound::DB00945',\n",
       " 'Compound::DB00571',\n",
       " 'Compound::DB00674',\n",
       " 'Compound::DB01234',\n",
       " 'Compound::DB00134',\n",
       " 'Compound::DB08842',\n",
       " 'Compound::DB00136',\n",
       " 'Compound::DB00368',\n",
       " 'Compound::DB12310',\n",
       " 'Compound::DB00682',\n",
       " 'Compound::DB00459',\n",
       " 'Compound::DB03929',\n",
       " 'Compound::DB02543',\n",
       " 'Compound::DB03575',\n",
       " 'Compound::DB01050',\n",
       " 'Compound::DB01017',\n",
       " 'Compound::DB00533',\n",
       " 'Compound::DB00859',\n",
       " 'Compound::DB04703',\n",
       " 'Compound::DB15357',\n",
       " 'Compound::DB01645',\n",
       " 'Compound::DB06479',\n",
       " 'Compound::DB12052',\n",
       " 'Compound::DB03614',\n",
       " 'Compound::DB13668',\n",
       " 'Compound::DB13084',\n",
       " 'Compound::DB00422',\n",
       " 'Compound::DB00809',\n",
       " 'Compound::DB11336',\n",
       " 'Compound::DB12148',\n",
       " 'Compound::DB14086',\n",
       " 'Compound::DB06750',\n",
       " 'Compound::DB01403',\n",
       " 'Compound::DB01611',\n",
       " 'Compound::DB01200',\n",
       " 'Compound::DB02701',\n",
       " 'Compound::DB00981',\n",
       " 'Compound::DB04557',\n",
       " 'Compound::DB00746',\n",
       " 'Compound::DB14128',\n",
       " 'Compound::DB00126',\n",
       " 'Compound::DB03467',\n",
       " 'Compound::DB04422',\n",
       " 'Compound::DB01041',\n",
       " 'Compound::DB00115',\n",
       " 'Compound::DB04864',\n",
       " 'Compound::DB01104',\n",
       " 'Compound::DB15579',\n",
       " 'Compound::DB00635',\n",
       " 'Compound::DB09163',\n",
       " 'Compound::DB13178',\n",
       " 'Compound::DB09149',\n",
       " 'Compound::DB12116',\n",
       " 'Compound::DB13132',\n",
       " 'Compound::DB01796',\n",
       " 'Compound::DB00734',\n",
       " 'Compound::DB01394',\n",
       " 'Compound::DB00091',\n",
       " 'Compound::DB04892',\n",
       " 'Compound::DB00325',\n",
       " 'Compound::DB01043',\n",
       " 'Compound::DB09151',\n",
       " 'Compound::DB01037',\n",
       " 'Compound::DB01065',\n",
       " 'Compound::DB11473',\n",
       " 'Compound::DB00641',\n",
       " 'Compound::DB00328',\n",
       " 'Compound::DB01381',\n",
       " 'Compound::DB04743',\n",
       " 'Compound::DB00989',\n",
       " 'Compound::DB00337',\n",
       " 'Compound::DB09148',\n",
       " 'Compound::DB01370',\n",
       " 'Compound::DB00656',\n",
       " 'Compound::DB04660',\n",
       " 'Compound::DB04365',\n",
       " 'Compound::DB01219',\n",
       " 'Compound::DB11797',\n",
       " 'Compound::DB01698',\n",
       " 'Compound::DB00987',\n",
       " 'Compound::DB00741',\n",
       " 'Compound::DB06527',\n",
       " 'Compound::DB01750',\n",
       " 'Compound::DB11725',\n",
       " 'Compound::DB01392',\n",
       " 'Compound::DB09157',\n",
       " 'Compound::DB00122',\n",
       " 'Compound::DB11672',\n",
       " 'Compound::DB02530',\n",
       " 'Compound::DB00313',\n",
       " 'Compound::DB00207']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ad_drugs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get pretrained model\n",
    "\n",
    "我们能直接使用我们的预训练模型 (TransE_l2)做药物重定位."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:46.380107Z",
     "start_time": "2022-11-15T06:14:46.377108Z"
    }
   },
   "outputs": [],
   "source": [
    "entity_id_file = '../../data/drkg/entities.tsv'\n",
    "relation_id_file = '../../data/drkg/relations.tsv'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get embeddings for diseases and drugs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:46.647948Z",
     "start_time": "2022-11-15T06:14:46.382095Z"
    }
   },
   "outputs": [],
   "source": [
    "# Get drugname/disease name to entity ID mappings\n",
    "entity_map = {}\n",
    "entity_id_map = {}\n",
    "relation_map = {}\n",
    "with open(entity_id_file, newline='', encoding='utf-8') as csvfile:\n",
    "    reader = csv.DictReader(csvfile, delimiter='\\t', fieldnames=['id', 'name'])\n",
    "    for row_val in reader:\n",
    "        entity_map[row_val['name']] = int(row_val['id'])\n",
    "        entity_id_map[int(row_val['id'])] = row_val['name']\n",
    "        \n",
    "with open(relation_id_file, newline='', encoding='utf-8') as csvfile:\n",
    "    reader = csv.DictReader(csvfile, delimiter='\\t', fieldnames=['id', 'name'])\n",
    "    for row_val in reader:\n",
    "        relation_map[row_val['name']] = int(row_val['id'])\n",
    "        \n",
    "# handle the ID mapping\n",
    "drug_ids = []\n",
    "disease_ids = []\n",
    "for drug in drug_list:\n",
    "    drug_ids.append(entity_map[drug])\n",
    "    \n",
    "for disease in AD_disease_list:\n",
    "    disease_ids.append(entity_map[disease])\n",
    "\n",
    "treatment_ids = [relation_map[treat]  for treat in treatment_list]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:46.654929Z",
     "start_time": "2022-11-15T06:14:46.649941Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 8104, 3)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(disease_ids),len(drug_ids),len(treatment_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:46.661914Z",
     "start_time": "2022-11-15T06:14:46.655927Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([83488, 45467, 25842], [9475, 11010, 7486], [24, 35, 68])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "disease_ids, drug_ids[:3], treatment_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:47.192981Z",
     "start_time": "2022-11-15T06:14:46.663905Z"
    }
   },
   "outputs": [],
   "source": [
    "# Load embeddings\n",
    "import torch\n",
    "import numpy as np\n",
    "entity_emb = np.load('../01-model/ckpts/TransE_l2_All_DRKG_0/All_DRKG_TransE_l2_entity.npy')\n",
    "rel_emb = np.load('../01-model/ckpts/TransE_l2_All_DRKG_0/All_DRKG_TransE_l2_relation.npy')\n",
    "\n",
    "drug_ids = torch.tensor(drug_ids).long()\n",
    "disease_ids = torch.tensor(disease_ids).long()\n",
    "treatment_ids = torch.tensor(treatment_ids)\n",
    "\n",
    "drug_emb = torch.tensor(entity_emb[drug_ids])\n",
    "treatment_embs = [torch.tensor(rel_emb[r_id]) for r_id in treatment_ids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:47.200961Z",
     "start_time": "2022-11-15T06:14:47.193979Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([83488, 45467, 25842]),\n",
       " tensor([ 9475, 11010,  7486]),\n",
       " tensor([24, 35, 68]))"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "disease_ids, drug_ids[:3], treatment_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:47.208939Z",
     "start_time": "2022-11-15T06:14:47.202956Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8104, 400])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "drug_emb.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Drug Repurposing Based on Edge Score\n",
    "\n",
    "我们使用下面算法来计算 the edge score. 注意, 这里我们用 logsigmoid 函数使全部 scores < 0. 分数越大, $(h, r, t)$ 越可能成立.\n",
    "\n",
    "$\\mathbf{d} = \\gamma - ||\\mathbf{h}+\\mathbf{r}-\\mathbf{t}||_{2}$\n",
    "\n",
    "$\\mathbf{score} = \\log\\left(\\frac{1}{1+\\exp(\\mathbf{-d})}\\right)$\n",
    "\n",
    "在进行药物重定位时，我们只使用与治疗相关的关系."
   ]
  },
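  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example of the formula above (the distance value is an assumption chosen for illustration, not taken from the model): with $\\gamma = 12$, the value used in the code below, and a distance $||\\mathbf{h}+\\mathbf{r}-\\mathbf{t}||_{2} = 10$,\n",
    "\n",
    "$d = 12 - 10 = 2$\n",
    "\n",
    "$\\mathrm{score} = \\log\\left(\\frac{1}{1+\\exp(-2)}\\right) \\approx -0.127$\n",
    "\n",
    "A smaller distance gives a larger $d$ and a score closer to 0, so triples whose embeddings fit better rank higher."
   ]
  },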
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:48.344902Z",
     "start_time": "2022-11-15T06:14:47.210934Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch.nn.functional as fn\n",
    "\n",
    "gamma = 12.0\n",
    "def transE_l2(head, rel, tail):\n",
    "    score = head + rel - tail\n",
    "    return gamma - torch.norm(score, p=2, dim=-1)\n",
    "\n",
    "scores_per_disease = []\n",
    "drugs = []\n",
    "for r_id in range(len(treatment_embs)):\n",
    "    treatment_emb = treatment_embs[r_id]\n",
    "    for disease_id in disease_ids:\n",
    "        disease_emb = entity_emb[disease_id]\n",
    "        score = fn.logsigmoid(transE_l2(drug_emb, treatment_emb, disease_emb))\n",
    "        scores_per_disease.append(score)\n",
    "        drugs.append(drug_ids)\n",
    "scores = torch.cat(scores_per_disease)\n",
    "drugs = torch.cat(drugs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:48.351884Z",
     "start_time": "2022-11-15T06:14:48.345899Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([72936]), torch.Size([72936]), 72936)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scores.shape, drugs.shape, 3*3*8104"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:48.445635Z",
     "start_time": "2022-11-15T06:14:48.353879Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((72936,), (72936,), 72936)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# sort scores in decending order\n",
    "# torch.flip: Reverse the order of a n-D tensor along given axis in dims.\n",
    "idx = torch.flip(torch.argsort(scores), dims=[0])\n",
    "scores = scores[idx].numpy()\n",
    "drugs = drugs[idx].numpy()\n",
    "\n",
    "scores.shape, drugs.shape, 3*3*8104"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Now we output proposed treatments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:48.562320Z",
     "start_time": "2022-11-15T06:14:48.447628Z"
    }
   },
   "outputs": [],
   "source": [
    "_, unique_indices = np.unique(drugs, return_index=True)\n",
    "topk = 50\n",
    "topk_indices = np.sort(unique_indices)[:topk]\n",
    "proposed_drugs = drugs[topk_indices]\n",
    "proposed_scores = scores[topk_indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-11-15T06:14:48.572294Z",
     "start_time": "2022-11-15T06:14:48.563318Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[6]\tCompound::DB04540\t-0.122935451567173\n",
      "[7]\tCompound::DB09341\t-0.12512117624282837\n",
      "[8]\tCompound::DB00143\t-0.12561222910881042\n",
      "[10]\tCompound::DB00515\t-0.1341310441493988\n",
      "[14]\tCompound::DB00997\t-0.14285585284233093\n",
      "[15]\tCompound::DB00171\t-0.1469261348247528\n",
      "[16]\tCompound::DB01229\t-0.14909186959266663\n",
      "[17]\tCompound::DB00477\t-0.15069124102592468\n",
      "[19]\tCompound::DB00755\t-0.15511256456375122\n",
      "[24]\tCompound::DB00502\t-0.16119839251041412\n",
      "[26]\tCompound::DB00783\t-0.1685544103384018\n",
      "[27]\tCompound::DB00295\t-0.16982416808605194\n",
      "[28]\tCompound::DB00661\t-0.17014111578464508\n",
      "[34]\tCompound::DB00675\t-0.17270652949810028\n",
      "[35]\tCompound::DB00624\t-0.17352452874183655\n",
      "[36]\tCompound::DB00363\t-0.17481908202171326\n",
      "[37]\tCompound::DB12153\t-0.17557311058044434\n",
      "[39]\tCompound::DB01708\t-0.17726033926010132\n",
      "[40]\tCompound::DB00541\t-0.17793041467666626\n",
      "[41]\tCompound::DB00959\t-0.179716095328331\n",
      "[44]\tCompound::DB00396\t-0.18199630081653595\n",
      "[45]\tCompound::DB00907\t-0.18305420875549316\n",
      "[46]\tCompound::DB04216\t-0.18521440029144287\n",
      "[50]\tCompound::DB00531\t-0.19235941767692566\n"
     ]
    }
   ],
   "source": [
    "top50_list = []\n",
    "\n",
    "for i in range(topk):\n",
    "    drug = int(proposed_drugs[i])\n",
    "    if entity_id_map[drug] not in ad_drugs:\n",
    "        score = proposed_scores[i]\n",
    "        row = \"[{}],{},{}\\n\".format(i+1, entity_id_map[drug], score)\n",
    "        top50_list.append(row)\n",
    "        print(\"[{}]\\t{}\\t{}\".format(i+1, entity_id_map[drug], score))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 输出结果到文件\n",
    "\n",
    "f = open('./results/02_transE_l2_top50.csv', 'w')\n",
    "for row in top50_list:\n",
    "    f.write(row)\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
